#pragma once
#include "TensorVoter.hpp"
#include <3rdparty/ANN4Vector.hpp>
#include <GraphicsAlgo/RotationMatrix.hpp>
#include <Utility/TextProgress.hpp>
#include <Xml/XML.hpp>
#include <Math/Random.hpp>

namespace zzz{

template<int D>
class TensorVoting : public IOData
{
public:
  //Build the random direction tables used to integrate plate and ball votes.
  //Level 0 (stick) uses the voter's own direction and needs no samples;
  //level i>0 uses SampleNumbers[i]=180^i unit vectors drawn by RandGen.RandOnDim(i+1).
  TensorVoting():sigma_(18.25),pred_(false),gridSize_(0)
  {
    const int e=180; //samples per additional sampled dimension
    SampleNumbers[0]=0;
    //written as a single loop so nothing past SampleNumbers[D-1] is touched (D==1 is safe)
    for (int i=1; i<D; i++) SampleNumbers[i]=(i==1) ? e : SampleNumbers[i-1]*e;

    for (int i=1; i<D; i++)
    {
      SampleDirections[i].reserve(SampleNumbers[i]);
      //the voting loops read SampleDirections[i][j] for every j<SampleNumbers[i],
      //so a rejected (non-unit) sample is redrawn instead of leaving the table short
      while (static_cast<int>(SampleDirections[i].size())<SampleNumbers[i])
      {
        Vector<D,double> dir=RandGen.RandOnDim(i+1);
        if (Abs(dir.Len()-1.0)<EPSILON) SampleDirections[i].push_back(dir);
        else cout<<"bad "<<dir<<endl;
      }
    }
  }

  //Every voter casts a tensor vote to the neighbours returned by an ANN range query
  //around its position; the results are written to votees.
  //votees start as a copy of voters (same positions, own tensors kept, i.e. an
  //implicit self vote). Each votee that received n>0 neighbour votes has its summed
  //tensor divided by n; the implicit self vote is not counted in n.
  //radius is passed unchanged to ANN4Vector::RangeQuery; the default 10 is the value
  //that used to be hard coded. TODO confirm whether RangeQuery expects a squared radius.
  void Vote(vector<TensorVoter<D> > &voters, vector<TensorVoter<D> > &votees, double radius=10)
  {
    votees.assign(voters.begin(),voters.end());
    //nothing to vote; also keeps voters.size()-1 below from wrapping around
    if (voters.empty()) return;

    vector<int> votecount(voters.size(),0);

    vector<Vector<D,double> > pos;
    pos.reserve(voters.size());
    for (zuint i=0; i<voters.size(); i++) pos.push_back(voters[i].Position);
    ANN4Vector<D,double> anntree(pos);

    TextProgress tp("%b %n",voters.size()-1,0);
    tp.ShowProgressBegin();
    for (zuint i=0; i<voters.size(); i++)
    {
      tp.ShowProgress(i);

      //the count returned by the first query is used as the result size of the second
      vector<int> idx;
      vector<double> dist;
      int number=anntree.RangeQuery(voters[i].Position,radius,0,idx,dist);
      anntree.RangeQuery(voters[i].Position,radius,number,idx,dist);

      for (zuint j=0; j<idx.size(); j++)
      {
        const int k=idx[j];
        if (k==static_cast<int>(i)) continue; //no explicit vote to itself
        votees[k].T+=GenTensorVote(voters[i],votees[k]);
        if (IsBad(votees[k].T[0])) cout<<"bad\n";
        votecount[k]++;
      }
    }
    for (zuint i=0; i<votees.size(); i++)
    {
      if (votecount[i]>0) votees[i].T/=votecount[i];
    }
    tp.ShowProgressEnd();
  }

//private:
  int SampleNumbers[D];                                  //sampled directions per level (0 for the stick level)
  Vector<D,vector<Vector<D,double> > > SampleDirections; //unit sample directions per level
  RandomHyperSphere2<D,double> RandGen;
  const double sigma_;                                   //scale of the decay exp(-(s^2+phi^2)/sigma_^2)

//private:
  //add the stick tensor StickVote*StickVote^T to TensorVote
  void Combine(Matrix<D,D,double> &TensorVote, const Vector<D,double> &StickVote)
  {
    Combine(TensorVote,StickVote,1.0);
  }

  //add Weight*StickVote*StickVote^T to TensorVote; only the upper triangle is
  //computed and mirrored, which keeps the result exactly symmetric
  void Combine(Matrix<D,D,double> &TensorVote, const Vector<D,double> &StickVote, double Weight)
  {
    for (int i=0; i<D; i++) for (int j=i; j<D; j++)
    {
      double more=Weight*StickVote[i]*StickVote[j];
      TensorVote(i,j)+=more;
      if (i!=j) TensorVote(j,i)+=more;
    }
  }

  //Vote tensor of the given level read from the precomputed table.
  //The problem is rotated by R so that Direction maps to (1,0,...), the table
  //(built for a voter at the origin) is interpolated at R*(votee-voter), and the
  //result is rotated back as R^T*T*R. Should be close to the analytic vote.
  Matrix<D,D,double> GetVotePre(const Vector<D,double> &voterPosition, \
                  const Vector<D,double> &Direction, \
                  const Vector<D,double> &voteePosition,\
                  int level)
  {
    Vector<D,double> newVoteePosition=voteePosition-voterPosition;
    Vector<D,double> newNormal(0);
    newNormal[0]=1;
    Matrix<D,D,double> mat=GetRotationBetweenVectors(Direction,newNormal);
    newVoteePosition=mat*newVoteePosition;
    Matrix<D,D,double> tensor=GetPreData(level,newVoteePosition);
    Matrix<D,D,double> matt(mat);
    matt.Transpose();
    tensor=matt*tensor*mat;
    return tensor;
  }

  //Stick vote for a voter at the origin with normal (1,0,...); helper for GenStickVote2.
  //Must agree with GenStickVote for that voter:
  // - coincident points or votee exactly in the tangent plane: the normal scaled
  //   by exp(-l^2/sigma_^2)
  // - folded angle between normal and votee direction below 45 degrees: zero vote
  // - otherwise: unit vector from the votee toward the circle center (1/phi,0,...),
  //   scaled by exp(-(s^2+phi^2)/sigma_^2), s = arc length, phi = 2*cos(angle)/l
  Vector<D,double> GenStickVote2Helper(const Vector<D,double> &voteePosition)
  {
    Vector<D,double> nor(0);
    nor[0]=1.0;
    //votee on the voter: handled first so the zero vector is never normalized
    if (voteePosition==Vector<D,double>(0)) return nor;

    Vector<D,double> l=voteePosition;

    //check if voter and votee are connected by high curvature
    double llen=l.Normalize();
    double cosangle=l[0];
    double angle=SafeACos(cosangle);
    if (angle>C_PI_2) angle=Abs(angle-PI);
    if (angle<C_PI_4) return Vector<D,double>(0); //smoothness constrain violated

    //voter and votee on a straight line: no sphere center exists.
    //Decays with distance exactly like GenStickVote does.
    if (angle==C_PI_2)
      return nor*exp(-llen*llen/sigma_/sigma_);

    //decide stick direction and length; cosangle keeps its sign so the center
    //lies on the same side of the voter as the votee
    double sintheta=cosangle;
    double arclen=(C_PI_2-angle)*llen/sintheta;
    double phi=2*sintheta/llen;
    Vector<D, double> center(0);
    center[0]=1.0/phi;
    Vector<D, double> stick_vote=center-voteePosition;
    stick_vote.Normalize();

    stick_vote*=exp(- (arclen*arclen+phi*phi)/sigma_/sigma_);
    return stick_vote;
  }

  //Stick vote computed in a canonical frame: translate the voter to the origin and
  //rotate Direction onto (1,0,...), evaluate GenStickVote2Helper, rotate back.
  //Should give the same result as GenStickVote.
  Vector<D,double> GenStickVote2(const Vector<D,double> &voterPosition, \
                  const Vector<D,double> &Direction, \
                  const Vector<D,double> &voteePosition)
  {
    Vector<D,double> newVoteePosition=voteePosition-voterPosition;
    Vector<D,double> newNormal(0);
    newNormal[0]=1;
    Matrix<D,D,double> mat=GetRotationBetweenVectors(Direction,newNormal);
    newVoteePosition=mat*newVoteePosition;
    Vector<D,double> stick_vote=GenStickVote2Helper(newVoteePosition);
    mat.Transpose(); //for rotation matrix transpose is invert
    return mat*stick_vote;
  }

  //Original analytic stick vote cast by a voter at voterPosition with normal
  //Direction onto voteePosition.
  // - folded angle between Direction and the voter->votee direction below 45
  //   degrees: zero vote (smoothness constraint)
  // - votee exactly in the tangent plane, or coincident with the voter:
  //   Direction*exp(-l^2/sigma_^2)
  // - otherwise: unit vector from the votee toward voterPosition+Direction/phi,
  //   scaled by exp(-(s^2+phi^2)/sigma_^2), s = arc length, phi = 2*cos(angle)/l
  Vector<D,double> GenStickVote(const Vector<D,double> &voterPosition, \
                  const Vector<D,double> &Direction, \
                  const Vector<D,double> &voteePosition)
  {
    Vector<D,double> v=voteePosition-voterPosition;

    //check if voter and votee are connected by high curvature
    double llen=v.Normalize();
    double cosangle=Dot(Direction,v);
    double angle=SafeACos(cosangle);
    if (angle>C_PI_2) angle=Abs(angle-PI);
    if (angle<C_PI_4)
      return Vector<D,double>(0); //smoothness constrain violated

    //voter and votee on a straight line, or voter and votee are the same point
    //this could avoid no sphere center exist
    if (angle==C_PI_2 || voterPosition==voteePosition)
      return Direction*exp(-llen*llen/sigma_/sigma_);

    //decide stick direction and length; cosangle keeps its sign so the center
    //lies on the same side of the voter as the votee
    double sintheta=cosangle;
    double arclen=(C_PI_2-angle)*llen/sintheta;
    double phi=2*sintheta/llen;
    Vector<D, double> stick_vote=voterPosition+Direction/phi-voteePosition;
    stick_vote.Normalize();

    stick_vote*=exp(- (arclen*arclen+phi*phi)/sigma_/sigma_);
    return stick_vote;
  }

  //Sample direction j of a plate/ball level, expressed for this voter.
  //Ball samples (level D-1) are used as drawn. Plate samples are multiplied by
  //voter.Directions.
  //NOTE(review): the stick level uses Directions.Row(0) as the normal; if rows hold
  //the eigenvectors, mapping plate samples into their span would need the
  //transpose - confirm against TensorVoter.
  Vector<D,double> SampledDirectionFor(const TensorVoter<D> &voter, int level, int j)
  {
    if (level==D-1) return SampleDirections[level][j];
    return voter.Directions*SampleDirections[level][j];
  }

  //Shared implementation of GetLevelTensor / GetLevelTensorPre (non-table path).
  //Level 0 is the voter's own stick vote. Higher levels average the stick votes of
  //SampleNumbers[level] sampled directions. With useStickTable the sampled stick
  //votes come from the precomputed level-0 table, otherwise from GenStickVote.
  Matrix<D,D,double> IntegrateLevelTensor(const TensorVoter<D> &voter, const Vector<D,double> &voteePosition, int level, bool useStickTable)
  {
    if (level==0)
    {
      Matrix<D,D,double> stickTensor(0);
      Combine(stickTensor, GenStickVote(voter.Position,voter.Directions.Row(0),voteePosition));
      return stickTensor;
    }

    const bool isBall=(level==D-1);
    Matrix<D,D,double> tensor(0);
    for (int j=0; j<SampleNumbers[level]; j++)
    {
      Vector<D,double> dir=SampledDirectionFor(voter,level,j);
      if (useStickTable)
      {
        tensor+=GetVotePre(voter.Position,dir,voteePosition,0);
        if (isBall && IsBad(tensor[0]))
        {
          //fatal numerical problem: dump the inputs and stop with a failure status
          cout<<voter.Position<<endl;
          cout<<dir<<endl;
          cout<<voteePosition<<endl;
          exit(1);
        }
      }
      else
      {
        Combine(tensor, GenStickVote(voter.Position,dir,voteePosition));
      }
    }
    return tensor/SampleNumbers[level];
  }

  //Tensor cast by one level of the voter: stick (0), plate (0<level<D-1) or ball (D-1).
  //If this level has been precomputed, the table is used directly. Otherwise the
  //level is integrated from stick votes, which themselves come from the level-0
  //table when that one is precomputed.
  Matrix<D,D,double> GetLevelTensorPre(const TensorVoter<D> &voter, const Vector<D,double> &voteePosition, int level)
  {
    if (pred_[level])
      return GetVotePre(voter.Position,voter.Directions.Row(level),voteePosition,level);
    return IntegrateLevelTensor(voter,voteePosition,level,pred_[0]);
  }

  //Tensor cast by one level of the voter: stick (0), plate (0<level<D-1) or ball (D-1).
  //Always evaluated analytically; precomputed tables are ignored.
  //Applies the same sample-direction mapping as GetLevelTensorPre.
  Matrix<D,D,double> GetLevelTensor(const TensorVoter<D> &voter, const Vector<D,double> &voteePosition, int level)
  {
    return IntegrateLevelTensor(voter,voteePosition,level,false);
  }

  //One voter votes one votee.
  //Level i is weighted by its saliency: Lambda[i]-Lambda[i+1] for i<D-1, and
  //Lambda[D-1] for the ball. Only levels with positive saliency contribute.
  Matrix<D,D,double> GenTensorVote(const TensorVoter<D>&voter, const TensorVoter<D>&votee)
  {
    Vector<D,double> voterSaliency;
    for (int i=0; i<D-1; i++) voterSaliency[i]=voter.Lambda[i]-voter.Lambda[i+1];
    voterSaliency[D-1]=voter.Lambda[D-1];

    Matrix<D,D,double> outTensor(0);
    for (int i=0; i<D; i++)
    {
      if (voterSaliency[i]>0)
      {
        outTensor+=GetLevelTensorPre(voter,votee.Position,i)*voterSaliency[i];
      }
    }

    return outTensor;
  }

public:
  //Precompute and cache the vote tensors of every level that is not cached yet.
  //Samples are taken on a regular grid around a voter at the origin.
  //sigma: NOTE(review) currently unused - votes always use the member sigma_
  //       (18.25, fixed at construction). The paper suggests 0.18.
  //radius: half the side length of the sampled cube (must be >= 0)
  //gridsize: spacing between samples (must be > 0). Each level gets a grid of
  //          (2*floor(radius/gridsize)+1)^D samples.
  //Level 0 is built first, so higher levels integrate its table.
  void Precompute(double sigma, double radius, double gridsize)
  {
    //guards against a division by zero / negative grid size (UB in the int conversion)
    if (!(gridsize>0) || radius<0)
    {
      cout<<"Precompute error, gridsize must be positive and radius non-negative!\n";
      return;
    }

    TensorVoter<D> tmpvoter(Vector<D,double>(0));

    gridSize_=gridsize;
    int size=static_cast<int>(radius/gridsize);
    center_=size;
    size=size*2+1;
    Vector<D,zuint> Size(size);
    for (int i=0; i<D; i++)
    {
      if (pred_[i]) continue;
      cout<<i<<'/'<<D<<endl;
      data_[i].SetSize(Size);
      Vector<D,zuint> coord(0);

      int curpos=0;
      TextProgress tp("%b %n",data_[i].size()-1,0);
      tp.ShowProgressBegin();

      bool good=true;
      while(good)
      {
        tp.ShowProgress(curpos++);
        //grid index -> position relative to the voter
        Vector<D,double> realcoord(coord);
        realcoord-=Vector<D,double>(center_);
        realcoord*=gridSize_;
        //sample
        data_[i](coord)=GetLevelTensorPre(tmpvoter,realcoord,i);
        //advance coord like an odometer, last dimension fastest
        coord[D-1]++;
        int cur=D-1;
        while(true)
        {
          if (coord[cur]==size)
          {
            coord[cur]=0;
            cur--;
            if (cur==-1) //overflow: every grid point visited
            {
              good=false;
              break;
            }
            coord[cur]++;
          }
          else break;  //done
        }
      }
      tp.ShowProgressEnd();
      pred_[i]=true;
    }
  }

  //Interpolated precomputed tensor of a level.
  //coord is relative to a voter at the origin with normal (1,0,...). It is
  //converted to grid units and shifted so the grid center is the origin.
  Matrix<D,D,double> GetPreData(int level, Vector<D,double> coord)
  {
    coord/=gridSize_;
    coord+=Vector<D,double>((data_[level].Size()-Vector<D,zuint>(1))/2);
    return data_[level].Interpolate(coord);
  }

  //Save the precomputed tables as XML.
  //Stores the grid size, the dimension, and one base64-encoded node per level.
  bool SaveFile(const string &filename)
  {
    XML xml;
    XMLNode head=xml.AppendNode("TensorVotePreData");
    head<<make_pair("GridSize",ToString(gridSize_));
    head<<make_pair("Dimension",ToString(D));
    for (int i=0; i<D; i++)
    {
      XMLNode tnode=head.AppendNode(StringPrintf("TensorNode_%d",i).c_str());
      tnode<<make_pair("SampleSize",ToString(data_[i].Size()));
      string code;
      Base64Encode((char*)data_[i].Data(),data_[i].size()*sizeof(Matrix<D,D,double>),code);
      tnode<<code;
    }
    return xml.SaveFile(filename);
  }

  //Load tables written by SaveFile.
  //Returns false if the file cannot be read or its dimension differs from D.
  //On success every level is marked as precomputed.
  bool LoadFile(const string &filename)
  {
    XML xml;
    if (!xml.LoadFile(filename)) return false;
    XMLNode head=xml.GetNode("TensorVotePreData");
    if (D!=FromString<int>(head.GetAttribute("Dimension")))
    {
      cout<<"Load error, Dimension dismatch!\n";
      return false;
    }
    pred_=false;
    gridSize_=FromString<double>(head.GetAttribute("GridSize"));
    for (int i=0; i<D; i++)
    {
      XMLNode tnode=head.GetNode(StringPrintf("TensorNode_%d",i).c_str());
      Vector<D,zuint> size=FromString<Vector<D,zuint> >(tnode.GetAttribute("SampleSize"));
      data_[i].SetSize(size);
      Base64Decode(tnode.GetText(),(char*)data_[i].Data(),data_[i].size()*sizeof(Matrix<D,D,double>));
      pred_[i]=true;
    }
    center_=(data_[0].Size()-Vector<D,zuint>(1))/2;
    return true;
  }

//private:
  //precomputed data_, one D-dimensional grid of tensors per level
  Vector<D, Array<D,Matrix<D,D,double> > > data_;
  Vector<D, bool> pred_;    //pred_[level]: data_[level] holds a valid table
  double gridSize_;         //spacing of the precomputed grid
  Vector<D, zuint> center_; //grid index of the voter position
};

}
